import json
from typing import Callable, Dict, List, Union

from pydantic import BaseModel, Field

from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction
from lagent.agents.agent import Agent, AsyncAgent
from lagent.agents.aggregator import DefaultAggregator
from lagent.hooks import ActionPreprocessor
from lagent.llms import BaseLLM
from lagent.memory import Memory
from lagent.prompts.parsers.json_parser import JSONParser
from lagent.prompts.prompt_template import PromptTemplate
from lagent.schema import AgentMessage

# Prompt template for tool selection. Its {action_info} and {output_format}
# placeholders match the keyword arguments that the agent constructors in this
# module pass to ``template.format(...)``.
select_action_template = """你是一个可以调用外部工具的助手，可以使用的工具包括：
{action_info}
{output_format}
开始!"""

# Prompt template describing the expected reply format. The demo at the bottom
# of this file hands it to JSONParser together with ``function_format`` and
# ``finish_format`` models, matching the two placeholders below.
output_format_template = """如果使用工具请遵循以下格式回复：
{function_format}

如果你已经知道了答案，或者你不需要工具，请遵循以下格式回复
{finish_format}"""


class ReAct(Agent):
    """ReAct (Reasoning and Acting) agent.

    Alternates between a selection agent, which decides on the next step, and
    an action executor, which runs the chosen tool, until the finish
    condition accepts a message or the turn budget is exhausted.
    """

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        actions: Union[BaseAction, List[BaseAction]],
        template: Union[PromptTemplate, str] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(type=JSONParser),
        aggregator: Dict = dict(type=DefaultAggregator),
        hooks: List = [dict(type=ActionPreprocessor)],
        finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content
        or 'conclusion' in m.formatted,
        max_turn: int = 5,
        **kwargs
    ):
        """Build the action executor and the selection agent.

        Args:
            llm: Language model instance or config, forwarded to the
                selection agent.
            actions: Action(s) made available through the executor.
            template: Prompt template formatted with ``action_info`` and
                ``output_format``. When ``None``, the module-level
                ``select_action_template`` is used.
            memory: Memory config forwarded to the selection agent.
            output_format: Output parser; its ``format_instruction()`` result
                is embedded in the prompt and the parser itself is forwarded
                to the selection agent.
            aggregator: Aggregator config forwarded to the selection agent.
            hooks: Hooks passed to both the executor and the selection agent.
            finish_condition: Predicate that ends the loop when it returns
                true for the selection agent's message.
            max_turn: Maximum number of select/act iterations.
            **kwargs: Forwarded to ``Agent.__init__``.
        """
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        self.actions = ActionExecutor(actions=actions, hooks=hooks)

        # The signature advertises ``template=None``; without this fallback
        # the ``.format`` call below raises AttributeError on None.
        if template is None:
            template = select_action_template

        # NOTE(review): ``output_format`` must expose ``format_instruction()``.
        # The dict default in the signature does not, so callers currently
        # have to pass a parser instance.
        system_prompt = template.format(
            action_info=json.dumps(self.actions.description()),
            output_format=output_format.format_instruction(),
        )
        self.select_agent = Agent(
            llm=llm,
            template=system_prompt,
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage:
        """Run the select/act loop for one incoming message.

        Args:
            message: Message to process.
            session_id: Session identifier passed to both the selection agent
                and the action executor.
            **kwargs: Forwarded to the selection agent only.

        Returns:
            The selection agent's message as soon as ``finish_condition``
            accepts it; otherwise the message produced by the last action
            execution. If ``max_turn`` is not positive, the input message is
            returned unchanged.
        """
        for _ in range(self.max_turn):
            # Ask the selection agent what to do next.
            message = self.select_agent(message, session_id=session_id, **kwargs)
            if self.finish_condition(message):
                return message
            # Not finished yet: run the selected action and feed the result
            # into the next turn.
            message = self.actions(message, session_id=session_id)
        return message


class AsyncReAct(AsyncAgent):
    """Asynchronous ReAct (Reasoning and Acting) agent.

    Same select/act loop as the synchronous ``ReAct`` class, built on
    ``AsyncAgent`` and ``AsyncActionExecutor`` so that both steps are awaited.
    """

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        actions: Union[BaseAction, List[BaseAction]],
        template: Union[PromptTemplate, str] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(type=JSONParser),
        aggregator: Dict = dict(type=DefaultAggregator),
        hooks: List = [dict(type=ActionPreprocessor)],
        finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content
        or 'conclusion' in m.formatted,
        max_turn: int = 5,
        **kwargs
    ):
        """Build the async action executor and the async selection agent.

        Args:
            llm: Language model instance or config, forwarded to the
                selection agent.
            actions: Action(s) made available through the executor.
            template: Prompt template formatted with ``action_info`` and
                ``output_format``. When ``None``, the module-level
                ``select_action_template`` is used.
            memory: Memory config forwarded to the selection agent.
            output_format: Output parser; its ``format_instruction()`` result
                is embedded in the prompt and the parser itself is forwarded
                to the selection agent.
            aggregator: Aggregator config forwarded to the selection agent.
            hooks: Hooks passed to both the executor and the selection agent.
            finish_condition: Predicate that ends the loop when it returns
                true for the selection agent's message.
            max_turn: Maximum number of select/act iterations.
            **kwargs: Forwarded to ``AsyncAgent.__init__``.
        """
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        self.actions = AsyncActionExecutor(actions=actions, hooks=hooks)

        # The signature advertises ``template=None``; without this fallback
        # the ``.format`` call below raises AttributeError on None.
        if template is None:
            template = select_action_template

        # NOTE(review): ``output_format`` must expose ``format_instruction()``.
        # The dict default in the signature does not, so callers currently
        # have to pass a parser instance.
        system_prompt = template.format(
            action_info=json.dumps(self.actions.description()),
            output_format=output_format.format_instruction(),
        )
        self.select_agent = AsyncAgent(
            llm=llm,
            template=system_prompt,
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage:
        """Run the select/act loop for one incoming message, awaiting each step.

        Args:
            message: Message to process.
            session_id: Session identifier passed to both the selection agent
                and the action executor.
            **kwargs: Forwarded to the selection agent only.

        Returns:
            The selection agent's message as soon as ``finish_condition``
            accepts it; otherwise the message produced by the last action
            execution. If ``max_turn`` is not positive, the input message is
            returned unchanged.
        """
        for _ in range(self.max_turn):
            # Ask the selection agent what to do next.
            message = await self.select_agent(message, session_id=session_id, **kwargs)
            if self.finish_condition(message):
                return message
            # Not finished yet: run the selected action and feed the result
            # into the next turn.
            message = await self.actions(message, session_id=session_id)
        return message


# Demo: exercise the synchronous and asynchronous agents.
if __name__ == '__main__':
    import asyncio
    from lagent.llms import GPTAPI, AsyncGPTAPI

    # NOTE: the model docstrings and Field descriptions below are kept
    # verbatim. pydantic can surface both in the schema it generates, so
    # editing them could change the text that ends up in the prompt.
    thought_description = '描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。'

    class ActionCall(BaseModel):
        """动作调用模型"""
        name: str = Field(description='调用的函数名称')
        parameters: Dict = Field(description='调用函数的参数')

    class ActionFormat(BaseModel):
        """动作格式模型"""
        thought_process: str = Field(description=thought_description)
        action: ActionCall = Field(description='当前步骤需要执行的操作，包括函数名称和参数。')

    class FinishFormat(BaseModel):
        """完成格式模型"""
        thought_process: str = Field(description=thought_description)
        conclusion: str = Field(description='总结当前的搜索结果，回答问题。')

    def build_llm_config(api_cls):
        """Return a fresh LLM config dict for the given API class."""
        return {
            'type': api_cls,
            'model_type': 'gpt-4o-2024-05-13',
            'max_new_tokens': 4096,
            'proxies': {},
            'retry': 1000,
        }

    # Prompt template and output parser shared by both agents.
    prompt_template = PromptTemplate(select_action_template)
    output_format = JSONParser(
        output_format_template,
        function_format=ActionFormat,
        finish_format=FinishFormat,
    )

    # Synchronous agent: two consecutive user turns on the same instance.
    agent = ReAct(
        llm=build_llm_config(GPTAPI),
        template=prompt_template,
        output_format=output_format,
        aggregator={'type': 'lagent.agents.aggregator.DefaultAggregator'},
        actions=[{'type': 'lagent.actions.PythonInterpreter'}],
    )
    for question in ('用 Python 计算一下 3 ** 5', ' 2 ** 5 呢'):
        print(agent(AgentMessage(sender='user', content=question)))

    # Asynchronous agent: run one turn, then print the agent's state dict
    # (the reply itself is not printed).
    async_agent = AsyncReAct(
        llm=build_llm_config(AsyncGPTAPI),
        template=prompt_template,
        output_format=output_format,
        aggregator={'type': 'lagent.agents.aggregator.DefaultAggregator'},
        actions=[{'type': 'lagent.actions.AsyncPythonInterpreter'}],
    )
    first_question = AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5')
    asyncio.run(async_agent(first_question))
    print(async_agent.state_dict())
